Conversation
Pull request overview
Adds first-class RL + SSM (RLSSM) support to HSSM by introducing a new RLSSM model that builds a differentiable PyTensor Op from an annotated JAX SSM log-likelihood and plugs it into the existing distribution-building pipeline.
Changes:
- Introduces the `RLSSM` model class plus the RL utility `validate_balanced_panel`.
- Extends configuration via `RLSSMConfig.ssm_logp_func` and exposes RLSSM in the public API.
- Adds test coverage for RLSSM initialization/model build and updates `RLSSMConfig` validation tests.
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| `src/hssm/rl/rlssm.py` | New RLSSM model implementation integrating the RL likelihood Op into HSSMBase. |
| `src/hssm/rl/utils.py` | Adds a balanced-panel validation helper for RLSSM datasets. |
| `src/hssm/rl/__init__.py` | RL subpackage exports for RLSSM and utilities. |
| `src/hssm/config.py` | Adds `ssm_logp_func` to `RLSSMConfig` and validates its presence. |
| `src/hssm/__init__.py` | Exposes `RLSSM` / `RLSSMConfig` at the top level. |
| `tests/test_rlssm.py` | New end-to-end-ish RLSSM tests (init, model build, balanced panel, smoke sampling). |
| `tests/test_rlssm_config.py` | Updates `RLSSMConfig` tests to include the new required field. |
```python
            f"same number of trials. Observed trial counts: {dict(counts)}"
        )

    return int(len(counts)), int(counts.iloc[0])
```
validate_balanced_panel only checks equal trial counts, but the RL likelihood builder reshapes the row order into (n_participants, n_trials, ...) (see make_rl_logp_func), which assumes each participant’s trials are in one contiguous block (and usually in-trial order). With interleaved participants, the panel can be “balanced” yet produce a silently incorrect likelihood. Consider validating contiguity (each participant appears in exactly one run of length n_trials) and/or sorting by participant_col (+ an optional trial_col if present) before returning (n_participants, n_trials).
Suggested change (replacing the final `return` statement):

```python
    # Ensure that each participant's trials form a single contiguous block
    # of rows of length n_trials. This is required because downstream code
    # reshapes the data into (n_participants, n_trials, ...) based on row
    # order, assuming no interleaving across participants.
    n_trials = int(counts.iloc[0])
    # Identify contiguous "blocks" of identical participant IDs.
    blocks = data[participant_col].ne(data[participant_col].shift()).cumsum()
    block_counts = data.groupby([participant_col, blocks]).size()
    # Each participant must appear in exactly one block, and that block
    # must have length n_trials.
    blocks_per_participant = block_counts.groupby(level=0).size()
    invalid_multi_blocks = blocks_per_participant[blocks_per_participant != 1]
    invalid_block_sizes = block_counts[block_counts != n_trials]
    if not invalid_multi_blocks.empty or not invalid_block_sizes.empty:
        raise ValueError(
            "Data must be ordered so that each participant's trials appear in "
            "a single contiguous block of rows of length n_trials. "
            "Participants with non-contiguous or incorrectly sized blocks "
            f"were found. Consider sorting your data by '{participant_col}' "
            "and, if available, by a trial index column before building the "
            "RL likelihood."
        )
    return int(len(counts)), n_trials
```
```python
            "Please provide the correct participant column name via "
            "`participant_col`."
        )
```
groupby(participant_col) drops NaN participant IDs by default, which can make n_participants/n_trials incorrect without an explicit error. Consider adding a check like data[participant_col].isna().any() and raising a clear ValueError if participant IDs are missing.
Suggested change:

```python
    # Ensure there are no missing participant IDs, since groupby will drop NaNs
    # silently, which would make n_participants / n_trials incorrect.
    if data[participant_col].isna().any():
        raise ValueError(
            f"Column '{participant_col}' contains missing values. "
            "Please fill or remove rows with missing participant IDs before "
            "calling validate_balanced_panel."
        )
```
Pull request overview
Copilot reviewed 8 out of 8 changed files in this pull request and generated 2 comments.
src/hssm/rl/rlssm.py
Outdated
```python
        )

        # Rearrange data so missing rows come first (no-op when missing_data=False).
        self.data = _rearrange_data(self.data)
```
_rearrange_data(self.data) changes row order, but the RL logp Op reshapes trials purely by row order into (n_participants, n_trials, ...). If any rows are moved (e.g., when missing_data=True and rt == -999), this will break per-participant trial sequences and invalidate the RL learning dynamics. Since missing-data networks are not supported for RLSSM, consider raising an explicit error when missing_data/deadline handling is requested (or implement a participant-wise rearrangement that preserves within-subject order).
```python
    counts = data.groupby(participant_col).size()
    if counts.nunique() != 1:
        raise ValueError(
            "Data must form balanced panels: all participants must have the "
            f"same number of trials. Observed trial counts: {dict(counts)}"
        )

    return int(len(counts)), int(counts.iloc[0])
```
validate_balanced_panel() only checks equal trial counts via groupby().size(), but it does not validate that rows are ordered/grouped by participant. The RL likelihood builder (make_rl_logp_func) reshapes arrays with .reshape(n_participants, n_trials, -1) based purely on row order, so interleaved participant rows will silently mix subjects/trials and produce an incorrect likelihood. Consider either (a) enforcing contiguous blocks per participant (and optionally stable-sorting by participant_col + a trial index column if available) or (b) returning a sorted copy of the data and using that downstream.
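To make the failure mode concrete, here is a minimal, self-contained sketch (the column names `participant_id` and `rt` are illustrative) showing how a balanced but interleaved panel passes the current check yet reshapes incorrectly, and how a stable sort fixes it:

```python
import pandas as pd

# Two participants with two trials each, but rows interleaved by trial.
data = pd.DataFrame(
    {"participant_id": [0, 1, 0, 1], "rt": [0.5, 0.7, 0.6, 0.8]}
)

# The balanced-panel check passes: every participant has exactly 2 trials.
counts = data.groupby("participant_id").size()
assert counts.nunique() == 1

# But a pure row-order reshape mixes subjects: the first "participant row"
# contains one trial from participant 0 AND one from participant 1.
wrong = data["rt"].to_numpy().reshape(2, 2)
print(wrong[0])  # [0.5 0.7] -- trials from two different participants

# A stable sort by participant restores contiguous per-subject blocks.
sorted_rt = data.sort_values("participant_id", kind="stable")["rt"].to_numpy()
print(sorted_rt.reshape(2, 2)[0])  # [0.5 0.6] -- both trials from participant 0
```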
Pull request overview
Copilot reviewed 8 out of 8 changed files in this pull request and generated 4 comments.
src/hssm/rl/rlssm.py
Outdated
```python
        # All RLSSM parameters are treated as trialwise: the Op expects arrays of
        # length n_total_trials for every parameter, and make_distribution.logp
        # broadcasts scalar / (1,)-shaped tensors up to (n_obs,) accordingly.
        params_is_trialwise = [
            True for param_name in self.params if param_name != "p_outlier"
        ]

        extra_fields_data = (
            None
            if not self.extra_fields
            else [deepcopy(self.data[field].values) for field in self.extra_fields]
        )

        assert self.list_params is not None, "list_params should be set"
        # self.loglik was set to the pytensor Op built in __init__; cast to
        # narrow the inherited union type so make_distribution's type-checker
        # accepts it without a runtime penalty.
        loglik_op = cast("Callable[..., Any] | Op", self.loglik)
        return make_distribution(
            rv=self.model_name,
            loglik=loglik_op,
            list_params=self.list_params,
            bounds=self.bounds,
            lapse=self.lapse,
            extra_fields=extra_fields_data,
            params_is_trialwise=params_is_trialwise,
        )
```
params_is_trialwise is derived from self.params (excluding p_outlier), but it is passed alongside list_params=self.list_params. If self.list_params includes p_outlier (common in HSSMBase), this makes params_is_trialwise shorter and potentially misaligned with list_params, which can cause incorrect broadcasting or length-check failures in make_distribution. Build params_is_trialwise from self.list_params in the same order, marking p_outlier as non-trialwise.
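A minimal sketch of the suggested fix, assuming `list_params` already contains `p_outlier` when present (the helper name here is hypothetical): derive the flags from `list_params` itself so both sequences stay index-aligned.

```python
def build_params_is_trialwise(list_params: list[str]) -> list[bool]:
    """Flag every parameter as trialwise except p_outlier.

    Building the flags from list_params (rather than from a different
    parameter collection) keeps the two sequences the same length and in
    the same order, which make_distribution relies on.
    """
    return [name != "p_outlier" for name in list_params]


flags = build_params_is_trialwise(["v", "a", "t", "p_outlier"])
print(flags)  # [True, True, True, False]
```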
Pull request overview
Copilot reviewed 8 out of 8 changed files in this pull request and generated 2 comments.
Pull request overview
Copilot reviewed 8 out of 8 changed files in this pull request and generated 1 comment.
```python
        # Build the differentiable pytensor Op from the annotated SSM function.
        # This Op supersedes the loglik/loglik_kind workflow: it is passed as
        # `loglik` to HSSMBase so Config.validate() is satisfied, and
        # _make_model_distribution() uses it directly without any further wrapping.
        #
        # Pass copies of list_params / extra_fields so the closure inside
        # make_rl_logp_func captures its own isolated list objects. HSSMBase will
        # later append "p_outlier" to self.list_params (which is the SAME list
        # object as `list_params` above), and that mutation must NOT be visible to
        # the Op's _validate_args_length check at sampling time.
        loglik_op = make_rl_logp_op(
            ssm_logp_func=rlssm_config.ssm_logp_func,
            n_participants=n_participants,
            n_trials=n_trials,
            data_cols=list(data_cols),
            list_params=list(list_params),
            extra_fields=list(extra_fields),
        )
```
RLSSM builds the Op exclusively from rlssm_config.ssm_logp_func (and its .computed metadata) but never uses rlssm_config.learning_process. Since learning_process is still a required RLSSMConfig field, this creates a confusing/fragile API where users can supply learning functions that are silently ignored (or diverge from ssm_logp_func.computed). Consider making learning_process optional/removing it from RLSSMConfig, or validating it matches (or populates) ssm_logp_func.computed so there is a single source of truth.
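One way to enforce a single source of truth, sketched with assumed attribute names (`ssm_logp_func.computed` is taken from the comment above; the validator itself is hypothetical): if `learning_process` stays on `RLSSMConfig`, check it against the Op's metadata at validation time.

```python
def validate_learning_process(ssm_logp_func, learning_process) -> None:
    # `computed` is the annotation the Op builder actually consumes; if the
    # user also declares learning_process, the two must agree.
    computed = set(getattr(ssm_logp_func, "computed", ()))
    declared = set(learning_process or ())
    if declared and declared != computed:
        raise ValueError(
            f"learning_process {sorted(declared)} does not match "
            f"ssm_logp_func.computed {sorted(computed)}; remove one or "
            "make them consistent."
        )
```

Alternatively, dropping `learning_process` from `RLSSMConfig` avoids the duplication entirely.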
@krishnbera @AlexanderFengler @digicosmos86 Here is a first draft for the RLSSM class.
```python
    def __init__(
        self,
        data: pd.DataFrame,
        rlssm_config: RLSSMConfig,
```
I think we should keep the two classes as similar to each other as possible, so I'd prefer `model_config` here.
```python
        data: pd.DataFrame,
        rlssm_config: RLSSMConfig,
        participant_col: str = "participant_id",
        include: list[dict[str, Any] | Any] | None = None,
```
@AlexanderFengler `include` is a legacy naming convention from HDDM. However, to me it's kind of confusing now. Should we deprecate it in favor of something clearer? We could use an alias for now with a deprecation warning and completely remove it in a future release.
What would you call it instead here?
We would want to make that change globally not just for this class I guess.
Either way, would do that as a separate PR.
```python
        **kwargs: Any,
    ) -> None:
        # Validate config (ensures ssm_logp_func is present, etc.)
        rlssm_config.validate()
```
Should we initialize the parent class first?
```python
        # would scramble per-participant trial sequences and corrupt RL dynamics.
        # Raise early so the user gets a clear message before model construction.
        if missing_data is not False:
            raise ValueError(
```
I think we will implement it in the future. Maybe `NotImplementedError` for now?
src/hssm/rl/rlssm.py
Outdated
```python
        # Build a ModelConfig so HSSMBase._build_model_config can apply the
        # RLSSM-specific fields (response, list_params, choices, bounds, …).
        # default_priors is an empty dict (no parameter-specific priors pre-set)
        # so that the prior_settings="safe" mechanism in HSSMBase assigns
        # sensible priors from bounds. Populating it with params_default scalar
        # floats would fix every parameter as a constant, which is incorrect.
        mc = ModelConfig(
            response=(tuple(rlssm_config.response) if rlssm_config.response else None),
            list_params=list_params,
            choices=(tuple(rlssm_config.choices) if rlssm_config.choices else None),
            default_priors={},
            bounds=rlssm_config.bounds or {},
            extra_fields=extra_fields if extra_fields else None,
            backend="jax",  # RLSSM always uses the JAX backend
        )
```
I wouldn't do this, tbh. The purpose of inheritance is not to funnel subclass functionality into base-class functionality; rather, a subclass should extend base-class functionality through overrides.
Pull request overview
Copilot reviewed 8 out of 8 changed files in this pull request and generated 2 comments.
```python
    counts = data.groupby(participant_col).size()
    if counts.nunique() != 1:
        raise ValueError(
            "Data must form balanced panels: all participants must have the "
            f"same number of trials. Observed trial counts: {dict(counts)}"
        )

    return int(len(counts)), int(counts.iloc[0])
```
validate_balanced_panel() will raise an IndexError on empty input because counts is empty and counts.iloc[0] is accessed. Please add an explicit empty-data check (e.g., if data.empty: raise ValueError(...)) so callers (notably RLSSM.__init__) get a clear, consistent ValueError instead of an internal indexing error.
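A sketch of the guard, assuming the signature shown in the snippet above (the error wording is illustrative):

```python
import pandas as pd


def validate_balanced_panel(
    data: pd.DataFrame, participant_col: str = "participant_id"
) -> tuple[int, int]:
    # Fail fast on empty input so callers see a clear ValueError instead of
    # an IndexError from counts.iloc[0] below.
    if data.empty:
        raise ValueError(
            "`data` is empty: RLSSM requires at least one trial per participant."
        )
    counts = data.groupby(participant_col).size()
    if counts.nunique() != 1:
        raise ValueError("Data must form balanced panels.")
    return int(len(counts)), int(counts.iloc[0])


print(validate_balanced_panel(pd.DataFrame({"participant_id": [0, 0, 1, 1]})))
# (2, 2)
```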
Pull request overview
Copilot reviewed 11 out of 11 changed files in this pull request and generated 1 comment.
```diff
     response: list[str] | None = field(default_factory=DEFAULT_SSM_OBSERVED_DATA.copy)
-    choices: tuple[int, ...] | None = DEFAULT_SSM_CHOICES
+    choices: list[int] | tuple[int, ...] | None = DEFAULT_SSM_CHOICES
```
choices was widened to list[int] | tuple[int, ...], but Config.update_choices() is still typed/docs as tuple[int, ...] and _build_model_config can pass a list[int]. To avoid inconsistent public typing (and future mypy confusion), consider updating update_choices (and any related docstrings/types) to accept list[int] | tuple[int, ...] and optionally normalize to a single internal representation (e.g., always store a tuple).
Pull request overview
Copilot reviewed 11 out of 11 changed files in this pull request and generated 1 comment.
```python
        data: pd.DataFrame,
        rlssm_config: RLSSMConfig,
        participant_col: str = "participant_id",
        include: list[dict[str, Any] | Any] | None = None,
```
What would you call it instead here?
We would want to make that change globally not just for this class I guess.
Either way, would do that as a separate PR.
```python
            )
        if deadline is not False:
            raise ValueError(
                "RLSSM does not support `deadline` handling. "
```
@krishnbera do we actually have a solution for this?
```python
        """
        # Start with defaults
        config = cls.config_class.from_defaults(model, loglik_kind)
        # get_config_class is provided by Config/RLSSMConfig mixin through MRO
```
why does RLSSMConfig show up here in this file?
```python
    "decision_process": "ddm",
    "learning_process": {},
    "learning_process_loglik_kind": "blackbox",
    "decision_process_loglik_kind": "analytical",
```
`learning_process_loglik_kind` is not a valid concept.
This pull request introduces reinforcement learning sequential sampling model (RLSSM) support to the HSSM package. It adds a new `RLSSM` class, supporting configuration, likelihood construction, and data validation for RL+SSM models, and refines the configuration workflow to require a fully annotated log-likelihood function. The changes also improve the pre-commit configuration and update the package's public API.

Major features and changes:

1. RLSSM Model Integration
- Added the `RLSSM` class in `src/hssm/rl/rlssm.py` to support models that combine reinforcement learning processes with sequential sampling models. This class builds a differentiable pytensor Op from an annotated JAX log-likelihood function and enforces strict data requirements for balanced panels.
- Added `validate_balanced_panel` in `src/hssm/rl/utils.py` to ensure input data forms a balanced panel, which is required for RLSSM models.

2. Configuration Enhancements
- `RLSSMConfig` in `src/hssm/config.py` now requires an `ssm_logp_func` (an annotated JAX SSM log-likelihood function), replacing the previous `loglik`/`loglik_kind` workflow. Added runtime validation to ensure this function is callable and properly annotated. [1] [2] [3]
- `from_rlssm_dict` now accepts a config dictionary and extracts `ssm_logp_func` and `model_name` directly from it, simplifying model instantiation.

3. Public API and Package Structure
- Exposed `RLSSM` and `RLSSMConfig` in the package's public API via `src/hssm/__init__.py` and created a new `src/hssm/rl/__init__.py` for RL-related exports. [1] [2] [3]

4. Developer Experience
- Updated `.pre-commit-config.yaml` to exclude the `tests/` directory from `ruff` and `mypy` checks, streamlining development workflows.